{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test Auto Generate Annotations\n",
    "\n",
    "这个notebook用于测试 auto_generate_annotations 方法。该方法位于 `src/williamtoolbox/annotation.py`。\n",
    "\n",
    "## 方法说明\n",
    "\n",
    "auto_generate_annotations 方法用于自动对文档生成批注。它需要以下参数：\n",
    "- rag_name: RAG服务的名称\n",
    "- doc: 文档内容"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "from williamtoolbox.annotation import auto_generate_annotations\n",
    "import os\n",
    "\n",
    "# 设置测试用的文档内容\n",
    "test_doc = \"\"\"\n",
    "机器学习是人工智能的一个分支，它主要研究如何使用计算机模拟或实现人类的学习行为。\n",
    "机器学习算法可以根据数据自动学习规律，并对新数据进行预测。\n",
    "\n",
    "深度学习是机器学习的一个重要方向，它通过构建具有多层处理层的人工神经网络，\n",
    "可以学习数据中的抽象特征表示。深度学习在图像识别、语音识别等领域取得了突破性进展。\n",
    "\n",
    "强化学习是另一个重要的机器学习方法，它通过agent与环境的交互来学习最优策略。\n",
    "在游戏、机器人控制等领域有广泛应用。\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备RAG服务\n",
    "\n",
    "在测试之前，我们需要确保有一个可用的RAG服务。这里我们使用命令行启动一个RAG服务："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 确保我们在正确的工作目录\n",
    "!pwd\n",
    "\n",
    "# 启动RAG服务\n",
    "!byzerllm deploy --pretrained_model_type custom/bge \\\n",
    "--cpus_per_worker 0.001 \\\n",
    "--gpus_per_worker 0 \\\n",
    "--worker_concurrency 10 \\\n",
    "--model_path ~/.auto-coder/storage/models/AI-ModelScope/bge-large-zh \\\n",
    "--infer_backend transformers \\\n",
    "--num_workers 1 \\\n",
    "--model emb\n",
    "\n",
    "# 启动一个RAG服务示例\n",
    "!auto-coder.rag serve \\\n",
    "--model deepseek_chat \\\n",
    "--tokenizer_path ~/.auto-coder/storage/models/deepseek-ai/deepseek-llm-67b-chat \\\n",
    "--doc_dir /tmp/test-docs \\\n",
    "--port 8001 \\\n",
    "--host 0.0.0.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 auto_generate_annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def test_auto_generate():\n",
    "    try:\n",
    "        # 调用方法生成批注\n",
    "        result = await auto_generate_annotations(\"test_rag\", test_doc)\n",
    "        \n",
    "        # 打印结果\n",
    "        print(\"文档内容:\")\n",
    "        print(result.doc_text)\n",
    "        print(\"\\n生成的批注:\")\n",
    "        for i, annotation in enumerate(result.annotations, 1):\n",
    "            print(f\"\\n批注 {i}:\")\n",
    "            print(f\"文本: {annotation.text}\")\n",
    "            print(f\"批注: {annotation.comment}\")\n",
    "            \n",
    "    except Exception as e:\n",
    "        print(f\"测试失败: {str(e)}\")\n",
    "\n",
    "# 运行测试\n",
    "await test_auto_generate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试多个文档场景"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 测试具有多个段落和不同主题的文档\n",
    "test_doc_2 = \"\"\"\n",
    "自然语言处理（NLP）是人工智能中一个重要的研究方向。\n",
    "它致力于让计算机理解和处理人类语言。近年来，基于Transformer的模型\n",
    "如BERT、GPT等在多个NLP任务上取得了显著进展。\n",
    "\n",
    "计算机视觉是另一个重要领域，主要研究如何让计算机理解图像和视频内容。\n",
    "卷积神经网络（CNN）是计算机视觉中最常用的深度学习模型之一。\n",
    "\"\"\"\n",
    "\n",
    "async def test_multiple_docs():\n",
    "    try:\n",
    "        # 对多个文档生成批注\n",
    "        results = await asyncio.gather(\n",
    "            auto_generate_annotations(\"test_rag\", test_doc),\n",
    "            auto_generate_annotations(\"test_rag\", test_doc_2)\n",
    "        )\n",
    "        \n",
    "        # 打印结果\n",
    "        for i, result in enumerate(results, 1):\n",
    "            print(f\"\\n文档 {i} 的批注:\")\n",
    "            for j, annotation in enumerate(result.annotations, 1):\n",
    "                print(f\"\\n批注 {j}:\")\n",
    "                print(f\"文本: {annotation.text}\")\n",
    "                print(f\"批注: {annotation.comment}\")\n",
    "                \n",
    "    except Exception as e:\n",
    "        print(f\"测试失败: {str(e)}\")\n",
    "\n",
    "# 运行多文档测试\n",
    "await test_multiple_docs()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试错误处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def test_error_cases():\n",
    "    try:\n",
    "        # 测试空文档\n",
    "        print(\"测试空文档:\")\n",
    "        result = await auto_generate_annotations(\"test_rag\", \"\")\n",
    "        print(\"结果:\", result)\n",
    "    except Exception as e:\n",
    "        print(f\"预期的错误: {str(e)}\")\n",
    "        \n",
    "    try:\n",
    "        # 测试不存在的RAG服务\n",
    "        print(\"\\n测试不存在的RAG服务:\")\n",
    "        result = await auto_generate_annotations(\"non_existent_rag\", test_doc)\n",
    "        print(\"结果:\", result)\n",
    "    except Exception as e:\n",
    "        print(f\"预期的错误: {str(e)}\")\n",
    "\n",
    "# 运行错误测试\n",
    "await test_error_cases()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 chat_with_rag\n",
    "\n",
    "接下来我们要测试 chat_with_rag 方法。这个方法用于与指定的 RAG 服务进行聊天交互。我们将使用与 auto_generate_annotations 中相似的 RAG 调用方式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from williamtoolbox.annotation import chat_with_rag\n",
    "\n",
    "async def test_chat_with_rag():\n",
    "    try:\n",
    "        # 准备测试消息\n",
    "        messages = [\n",
    "            {\"role\": \"user\", \"content\": \"你能告诉我机器学习是什么吗？\"}\n",
    "        ]\n",
    "        \n",
    "        # 调用 chat_with_rag 方法\n",
    "        response = await chat_with_rag(\"test_rag\", messages)\n",
    "        \n",
    "        print(\"RAG 回答:\")\n",
    "        print(response)\n",
    "        \n",
    "    except Exception as e:\n",
    "        print(f\"测试失败: {str(e)}\")\n",
    "\n",
    "# 运行测试\n",
    "await test_chat_with_rag()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def test_chat_with_rag_conversation():\n",
    "    try:\n",
    "        # 准备多轮对话消息\n",
    "        messages = [\n",
    "            {\"role\": \"user\", \"content\": \"什么是深度学习？\"},\n",
    "            {\"role\": \"assistant\", \"content\": \"深度学习是机器学习的一个分支，主要使用多层神经网络进行学习。\"},\n",
    "            {\"role\": \"user\", \"content\": \"那神经网络是什么？\"}\n",
    "        ]\n",
    "        \n",
    "        # 调用 chat_with_rag 方法进行多轮对话\n",
    "        response = await chat_with_rag(\"批注知识库\", messages)\n",
    "        \n",
    "        print(\"多轮对话测试:\")\n",
    "        print(\"问题历史:\")\n",
    "        for msg in messages:\n",
    "            print(f\"{msg['role']}: {msg['content']}\")\n",
    "        print(\"\\nRAG 回答:\")\n",
    "        print(response)\n",
    "        \n",
    "    except Exception as e:\n",
    "        print(f\"测试失败: {str(e)}\")\n",
    "\n",
    "# 运行多轮对话测试\n",
    "await test_chat_with_rag_conversation()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def test_chat_with_rag_errors():\n",
    "    try:\n",
    "        # 测试空消息列表\n",
    "        print(\"测试空消息列表:\")\n",
    "        result = await chat_with_rag(\"test_rag\", [])\n",
    "        print(\"结果:\", result)\n",
    "    except Exception as e:\n",
    "        print(f\"预期的错误: {str(e)}\")\n",
    "        \n",
    "    try:\n",
    "        # 测试不存在的RAG服务\n",
    "        print(\"\\n测试不存在的RAG服务:\")\n",
    "        messages = [{\"role\": \"user\", \"content\": \"测试消息\"}]\n",
    "        result = await chat_with_rag(\"non_existent_rag\", messages)\n",
    "        print(\"结果:\", result)\n",
    "    except Exception as e:\n",
    "        print(f\"预期的错误: {str(e)}\")\n",
    "\n",
    "# 运行错误测试\n",
    "await test_chat_with_rag_errors()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
